// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "aiter_hip_common.h"
#include <hip/hip_runtime.h>
#include "moe_sorting_api.hpp"
#include <memory>


// Argument block for the assembly kernel. FMoeKernel::launch_kernel fills one
// of these and hands its address and sizeof(KernelArgs) to the launcher.
//
// The struct is packed and every real field is followed by a p2/p3 filler
// member. p2 and p3 are declared outside this file.
// NOTE(review): presumably the fillers pad each field to a fixed-size slot
// that the kernel expects -- confirm against the p2/p3 definitions and the
// kernel's argument layout before reordering or resizing anything here.
struct __attribute__((packed)) KernelArgs
{
    void *ptr_O;            // set from `out`
    p2 _p0;
    void *ptr_X;            // set from `input`
    p2 _p1;
    void *ptr_GU;           // set from `w1`
    p2 _p2;
    void *ptr_XC;           // set from `num_valid_ids`
    p2 _p3;
    void *ptr_D;            // set from `w2`
    p2 _p4;
    void *ptr_XQ;           // set from `input_dqn` when T is uint8_t, else nullptr
    p2 _p5;
    void *ptr_GUQ;          // set from `w1_dqn` when T is uint8_t, else nullptr
    p2 _p6;
    void *ptr_DQ;           // set from `w2_dqn` when T is uint8_t, else nullptr
    p2 _p7;
    void *ptr_SMQ;          // set from `w2_smooth_qnt` (if provided) when T is uint8_t, else nullptr
    p2 _p8;
    void *ptr_STP;          // set from `sorted_token_ids`
    p2 _p9;
    void *ptr_SW;           // set from `sorted_weight_buf`
    p2 _p10;
    void *ptr_SEP;          // set from `sorted_expert_ids`
    p2 _p11;
    unsigned int dim;       // set from `dim`
    p3 _p12;
    unsigned int inter_dim; // set from `inter_dim` (multiplied by 8 in int4 mode)
    p3 _p13;
    unsigned int token_cnt; // set from `num_tokens`
    p3 _p14;
    unsigned int eprt_cnt;  // set from `num_experts`
    p3 _p15;
    unsigned int Xs;        // input_stride_0 * sizeof(T)
    p3 _p16;
    unsigned int GUs;       // dim * sizeof(T), halved in int4 mode
    p3 _p17;
    unsigned int Ds;        // inter_dim * sizeof(T), halved in int4 mode
    p3 _p18;
    unsigned int Os;        // dim * sizeof(T_O)
    p3 _p19;
    unsigned int eGUs;      // GUs * inter_dim (per-expert step through w1)
    p3 _p20;
    unsigned int eDs;       // Ds * dim (per-expert step through w2)
    p3 _p21;
    unsigned int eGUQs;     // w1_stride_0 * sizeof(float) when w1_dqn is given, else 0
    p3 _p22;
    unsigned int eDQs;      // w2_stride_0 * sizeof(float) when w2_dqn is given, else 0
    p3 _p23;
    unsigned int eSMQs;     // inter_dim * sizeof(float)
    p3 _p24;
    unsigned int topk;      // set from `topk`
    p3 _p25;
};


// Kernel binary embedded at template-render time ({{bin_size}} bytes of
// {{bin_data}}); FMoeKernel's constructor passes it to AiterAsmKernelFast.
// NOTE(review): the name suggests an HSA code object -- confirm with the generator.
unsigned char hsaco[{{bin_size}}] = { {{bin_data}} };

// Host-side launcher for the embedded fused-MoE assembly kernel.
//
// The constructor builds an AiterAsmKernelFast from the `hsaco` bytes and the
// template-selected kernel name; launch_kernel() fills a KernelArgs block and
// dispatches it on the caller's stream.
class FMoeKernel
{
private:
    std::unique_ptr<AiterAsmKernelFast> asm_kernel = nullptr;
    uint32_t sub_GU = 512; // tile size used to split inter_dim across grid X; overwritten in the ctor
    bool is_int4 = false;  // toggled through set_int4()

public:
    FMoeKernel()
    {
        asm_kernel = std::make_unique<AiterAsmKernelFast>("{{kernel_name}}", hsaco);
        this->sub_GU = {{selected_tile}};
    }

    // Enables/disables int4 mode: launch_kernel then multiplies inter_dim by 8
    // and halves the w1/w2 row strides.
    void set_int4(bool is_int4_)
    {
        is_int4 = is_int4_;
    }

    // Fills the kernel argument block and launches the assembly kernel.
    //
    // Template parameters:
    //   T         - input/weight element type; only sizeof(T) and "is it uint8_t"
    //               are used here.
    //   T_O       - output element type; only sizeof(T_O) is used.
    //   switchGxy - when true the grid X and Y extents are swapped at launch.
    //
    // The quantization pointers (input_dqn, w1_dqn, w2_dqn, w2_smooth_qnt) are
    // forwarded to the kernel only when T is uint8_t; in that case input_dqn,
    // w1_dqn and w2_dqn must hold a value (std::optional::value() throws
    // std::bad_optional_access otherwise). Likewise w1_stride_0 / w2_stride_0
    // must be set whenever w1_dqn / w2_dqn are.
    // w2_smooth_stride_0 is accepted but not read.
    template <typename T, typename T_O, bool switchGxy = false>
    void launch_kernel(void* out,               // [num_tokens, dim]
                       void* input,             // [num_tokens, dim] M,K
                       void* w1,                // [num_experts, inter_dim, dim] N,K
                       void* w2,                // [num_experts, dim, inter_dim]
                       void* sorted_token_ids,  // [max_num_tokens_padded]
                       void* sorted_weight_buf, // [max_num_tokens_padded]
                       void* sorted_expert_ids, // [max_num_m_blocks]
                       void* num_valid_ids,     // [1]
                       int topk,           
                       int num_tokens,
                       int dim,
                       int inter_dim,
                       int max_num_m_blocks,
                       int num_experts,
                       int input_stride_0,
                       const hipStream_t stream,
                       std::optional<void*> input_dqn = std::nullopt,
                       std::optional<void*> w1_dqn = std::nullopt,
                       std::optional<void*> w2_dqn = std::nullopt,
                       std::optional<void*> w2_smooth_qnt = std::nullopt,
                       std::optional<int> w1_stride_0 = std::nullopt,
                       std::optional<int> w2_stride_0 = std::nullopt,
                       std::optional<int> w2_smooth_stride_0 = std::nullopt
    )
    {
        // In int4 mode the caller-supplied inter_dim is scaled by 8.
        int inter_dim_ = is_int4 ? inter_dim * 8 : inter_dim;
        uint32_t I_elemSize = sizeof(T);
        uint32_t O_elemSize = sizeof(T_O);

        // Row strides, all expressed in bytes.
        int stride_X = input_stride_0 * I_elemSize;
        int stride_GU = dim * I_elemSize;
        int stride_D = inter_dim_ * I_elemSize;
        if (is_int4)
        {
            stride_GU /= 2;
            stride_D /= 2;
        }

        // Per-expert strides, derived from the row strides above.
        int stride_expert_GU = stride_GU * inter_dim_;
        int stride_expert_D = stride_D * dim;
        int stride_expert_GUDQN = w1_dqn.has_value() ? w1_stride_0.value() * sizeof(float) : 0;
        int stride_expert_DDQN = w2_dqn.has_value() ? w2_stride_0.value() * sizeof(float) : 0;
        int stride_expert_SMTDQN = inter_dim_ * sizeof(float);
        int stride_O = dim * O_elemSize;
        // NOTE(review): inter_dim_ is either inter_dim or inter_dim * 8, so this
        // condition can only hold for inter_dim == 0. It looks like it was meant
        // to compare against the actual w1 row count (gate+up fused layout);
        // confirm the intended operand. Behavior is left unchanged.
        if (inter_dim_ * 2 == inter_dim)
        {
            stride_expert_GU *= 2;
            // stride_expert_GUDQN *= 2;
        }

        // Value-initialise so the filler members between fields are zero rather
        // than indeterminate when the block is copied to the kernel.
        KernelArgs args{};
        size_t arg_size = sizeof(args);
        args.ptr_O = out;
        args.ptr_X = input;
        args.ptr_GU = w1;
        args.ptr_XC = num_valid_ids;
        args.ptr_D = w2;
        if constexpr (std::is_same<T, uint8_t>::value)
        {
            args.ptr_XQ = input_dqn.value();
            args.ptr_GUQ = w1_dqn.value();
            args.ptr_DQ = w2_dqn.value();
            args.ptr_SMQ = w2_smooth_qnt.has_value() ? w2_smooth_qnt.value() : nullptr;
        }
        else
        {
            args.ptr_XQ = nullptr;
            args.ptr_GUQ = nullptr;
            args.ptr_DQ = nullptr;
            args.ptr_SMQ = nullptr;
        }
        args.ptr_STP = sorted_token_ids;
        args.ptr_SW = sorted_weight_buf;
        args.ptr_SEP = sorted_expert_ids;
        args.dim = dim;
        args.inter_dim = inter_dim_;
        args.token_cnt = num_tokens;
        args.eprt_cnt = num_experts;
        args.Xs = stride_X;
        args.GUs = stride_GU;
        args.Ds = stride_D;
        args.Os = stride_O;
        args.eGUs = stride_expert_GU;
        args.eDs = stride_expert_D;
        args.eGUQs = stride_expert_GUDQN;
        args.eDQs = stride_expert_DDQN;
        args.eSMQs = stride_expert_SMTDQN;
        args.topk = topk;

        // Launch geometry: ceil(inter_dim_ / sub_GU) along one grid axis and one
        // entry per m-block along the other; bdx is passed as the first block extent.
        int bdx = 256;
        int gdx = ((inter_dim_ + this->sub_GU - 1) / this->sub_GU);
        int gdy = max_num_m_blocks;
        int gdz = 1;

        if constexpr (switchGxy)
        {
            asm_kernel->launch_kernel({&args, &arg_size, gdy, gdx, gdz, bdx, 1, 1, stream});
        }
        else
        {
            asm_kernel->launch_kernel({&args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, stream});
        }
    }
};


{% if input_dtype=="uint16_t" %}
// C-linkage declarations for the two entry points defined below, so their
// symbol names are not C++-mangled.
// NOTE(review): presumably these are resolved by name from a dynamically
// loaded library -- confirm with the caller side.
extern "C" {
// Sorts the routing data with moe_sorting and then launches the fused-MoE
// kernel; see the definition below for the exact sequence.
void {{func_name}}(void* out,               // [num_tokens, dim]
                   void* hidden_states,     // [num_tokens, dim] M,K
                   void* gate,              // [num_experts, inter_dim, dim] N,K
                   void* down,              // [num_experts, dim, inter_dim]
                   void* topk_weight,       // [num_tokens, topk]
                   void* topk_ids,          // [num_tokens, topk]
                   void* sorted_token_ids,  // [max_num_tokens_padded]
                   void* sorted_weight_buf, // [max_num_tokens_padded]
                   void* sorted_expert_ids, // [max_num_m_blocks]
                   void* num_valid_ids,     // [1]
                   int num_tokens,
                   int dim,
                   int inter_dim,
                   int topk,
                   int num_experts,
                   int max_num_m_blocks,
                   int out_bytes,
                   int hidden_states_stride_0,
                   void* expert_mask,
                   void* workspace,
                   void* stream);
                   
// Writes the sorting workspace size for the given problem into *workspace_size.
void moe_sorting_get_workspace_size(int num_tokens, int num_experts, int* workspace_size);
}


// C-linkage shim: forwards to the two-argument moe_sorting_get_workspace_size
// overload (declared outside this file) and stores the result through
// workspace_size, which must point to a writable int.
void moe_sorting_get_workspace_size(int num_tokens, int num_experts, int* workspace_size)
{
    const int required = moe_sorting_get_workspace_size(num_tokens, num_experts);
    *workspace_size = required;
}


// Entry point for the non-quantized path.
//
// 1. Checks that a workspace buffer was supplied when moe_sorting reports a
//    non-zero workspace size (throws std::runtime_error otherwise).
// 2. Runs moe_sorting on the caller's stream, filling the sorted_* buffers and
//    num_valid_ids; `out`/`out_bytes` are passed to it as well.
// 3. Launches the assembly kernel through a function-local static FMoeKernel.
void {{func_name}}(void* out,               // [num_tokens, dim]
                   void* hidden_states,     // [num_tokens, dim] M,K
                   void* gate,              // [num_experts, inter_dim, dim] N,K
                   void* down,              // [num_experts, dim, inter_dim]
                   void* topk_weight,       // [num_tokens, topk]
                   void* topk_ids,          // [num_tokens, topk]
                   void* sorted_token_ids,  // [max_num_tokens_padded]
                   void* sorted_weight_buf, // [max_num_tokens_padded]
                   void* sorted_expert_ids, // [max_num_m_blocks]
                   void* num_valid_ids,     // [1]
                   int num_tokens,
                   int dim,
                   int inter_dim,
                   int topk,
                   int num_experts,
                   int max_num_m_blocks,
                   int out_bytes,
                   int hidden_states_stride_0,
                   void* expert_mask,
                   void* workspace,
                   void* stream)
{
    // The stream arrives as an opaque pointer across the C boundary.
    hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);

    const bool needs_workspace = moe_sorting_get_workspace_size(num_tokens, num_experts) > 0;
    if (needs_workspace && workspace == nullptr)
    {
        throw std::runtime_error("workspace is required for non-oneshot sorting");
    }

    // A non-null expert_mask switches the sort into local-expert-mask mode.
    const bool use_local_expert_mask = (expert_mask != nullptr);

    moe_sorting(
        {
            "int32",              // index_type
            "fp32",               // weight_type; currently always float
            use_local_expert_mask // if mask experts as local expert
        },
        {
            topk_ids,
            topk_weight,
            expert_mask,
            sorted_token_ids,
            sorted_weight_buf,
            sorted_expert_ids,
            num_valid_ids,
            out,
            workspace,
            num_tokens,
            {{block_size}},
            num_experts,
            topk,
            out_bytes
        },
        {hip_stream});

    // Constructed once on first call and reused afterwards.
    static FMoeKernel impl{};

    impl.launch_kernel<{{input_dtype}}, {{output_dtype}}, {{switch_gxy}}>(
        out,
        hidden_states,
        gate,
        down,
        sorted_token_ids,
        sorted_weight_buf,
        sorted_expert_ids,
        num_valid_ids,
        topk,
        num_tokens,
        dim,
        inter_dim,
        max_num_m_blocks,
        num_experts,
        hidden_states_stride_0,
        hip_stream);
}
{% else %}
// C-linkage declaration of the quantized-path entry point defined below, so
// its symbol name is not C++-mangled.
// NOTE(review): unlike the non-quantized variant, this signature carries no
// inter_dim, max_num_m_blocks or hidden-state stride -- confirm whether that
// is intentional, since FMoeKernel::launch_kernel requires them.
extern "C" {
void {{func_name}}(void* out,               // [num_tokens, dim]
                   void* hidden_states,     // [num_tokens, dim] M,K
                   void* gate,              // [num_experts, inter_dim, dim] N,K
                   void* down,              // [num_experts, dim, inter_dim]
                   void* topk_weight,
                   void* topk_ids,
                   void* sorted_token_ids,  // [max_num_tokens_padded]
                   void* sorted_weight_buf, // [max_num_tokens_padded]
                   void* sorted_expert_ids, // [max_num_m_blocks]
                   void* num_valid_ids,     // [1]
                   int num_tokens,
                   int dim,
                   int topk,                    //
                   int num_experts,
                   int out_bytes,
                   void* input_scale,       // [num_tokens, 1]
                   void* fc1_scale,         // [num_experts, 1, inter_dim]
                   void* fc2_scale,         // [num_experts, 1, dim]
                   void* fc2_smooth_scale,  // [num_experts, 1, inter_dim],
                   void* a8,                // [M, dim]
                   void* a8_scale,          // [M]
                   void* per_tensor_quant_scale,
                   void* expert_mask,
                   void* workspace,
                   const void* stream);
}

// Entry point for the quantized path.
//
// 1. Checks that a workspace buffer was supplied when moe_sorting reports a
//    non-zero workspace size (throws std::runtime_error otherwise).
// 2. Runs moe_sorting on the caller's stream, filling the sorted_* buffers and
//    num_valid_ids; `out`/`out_bytes` are passed to it as well.
// 3. Launches the assembly kernel through a function-local static FMoeKernel,
//    switching it to int4 mode for the 32-bit packed input dtypes.
//
// a8, a8_scale and per_tensor_quant_scale are accepted but not read here.
void {{func_name}}(void* out,               // [num_tokens, dim]
                   void* hidden_states,     // [num_tokens, dim] M,K
                   void* gate,              // [num_experts, inter_dim, dim] N,K
                   void* down,              // [num_experts, dim, inter_dim]
                   void* topk_weight,
                   void* topk_ids,
                   void* sorted_token_ids,  // [max_num_tokens_padded]
                   void* sorted_weight_buf, // [max_num_tokens_padded]
                   void* sorted_expert_ids, // [max_num_m_blocks]
                   void* num_valid_ids,     // [1]
                   int num_tokens,
                   int dim,
                   int topk,                    //
                   int num_experts,
                   int out_bytes,
                   void* input_scale,       // [num_tokens, 1]
                   void* fc1_scale,         // [num_experts, 1, inter_dim]
                   void* fc2_scale,         // [num_experts, 1, dim]
                   void* fc2_smooth_scale,  // [num_experts, 1, inter_dim],
                   void* a8,                // [M, dim]
                   void* a8_scale,          // [M]
                   void* per_tensor_quant_scale,
                   void* expert_mask,
                   void* workspace,
                   const void* stream){
    int workspace_size = moe_sorting_get_workspace_size(num_tokens, num_experts);
    if (workspace_size > 0 && workspace==nullptr)
    {
        throw std::runtime_error("workspace is required for non-oneshot sorting");
    }

    // `stream` is const-qualified in the C signature; reinterpret_cast alone may
    // not drop that qualifier, so strip it explicitly before converting to the
    // HIP stream handle.
    hipStream_t hip_stream = reinterpret_cast<hipStream_t>(const_cast<void*>(stream));

    moe_sorting({
                    "int32",                    // index_type
                    "fp32",                       // weight_type; // currently always float
                    expert_mask!=nullptr          // if mask experts as local expert
                },
                {topk_ids, topk_weight, expert_mask, sorted_token_ids, sorted_weight_buf, sorted_expert_ids, num_valid_ids, out, workspace, num_tokens, {{block_size}}, num_experts, topk, out_bytes}, {hip_stream});

    // Constructed once on first call and reused afterwards.
    static FMoeKernel impl{};
    {% if input_dtype=="uint32_t" or input_dtype=="int32_t" %}
    // `impl` is an object, not a pointer, so use member access with '.'.
    impl.set_int4(true);
    {% endif %}
    // NOTE(review): this call does not match FMoeKernel::launch_kernel. After
    // `topk` the launcher expects num_tokens, dim, inter_dim, max_num_m_blocks,
    // num_experts, input_stride_0 and the stream before any quantization
    // pointers, and it needs w1_stride_0 / w2_stride_0 whenever fc1_scale /
    // fc2_scale are forwarded. This entry point's signature carries no
    // inter_dim, max_num_m_blocks or stride values, so the argument list cannot
    // be corrected without extending the interface; it is left as-is.
    impl.launch_kernel<{{input_dtype}}, {{output_dtype}}, {{switch_gxy}}>(out,
                                               hidden_states,
                                               gate,
                                               down,
                                               sorted_token_ids,
                                               sorted_weight_buf,
                                               sorted_expert_ids,
                                               num_valid_ids,
                                               topk,
                                               // quant args
                                               input_scale,
                                               fc1_scale,
                                               fc2_scale,
                                               fc2_smooth_scale);
}
{% endif %}
